{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.autograd as autograd\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as Data\n",
    "torch.manual_seed(8) # for reproduce\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import gc\n",
    "import sys\n",
    "sys.setrecursionlimit(50000)\n",
    "import pickle\n",
    "torch.backends.cudnn.benchmark = True\n",
    "torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
    "# from tensorboardX import SummaryWriter\n",
    "torch.nn.Module.dump_patches = True\n",
    "import copy\n",
    "import pandas as pd\n",
    "#then import my own modules\n",
    "from AttentiveFP import Fingerprint, Fingerprint_viz, save_smiles_dicts, get_smiles_dicts, get_smiles_array, moltosvg_highlight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit import Chem\n",
    "# from rdkit.Chem import AllChem\n",
    "from rdkit.Chem import QED\n",
    "from rdkit.Chem import rdMolDescriptors, MolSurf\n",
    "from rdkit.Chem.Draw import SimilarityMaps\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem\n",
    "from rdkit.Chem import rdDepictor\n",
    "from rdkit.Chem.Draw import rdMolDraw2D\n",
    "%matplotlib inline\n",
    "from numpy.polynomial.polynomial import polyfit\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import gridspec\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib\n",
    "import seaborn as sns; sns.set_style(\"darkgrid\")\n",
    "from IPython.display import SVG, display\n",
    "import sascorer\n",
    "import itertools\n",
    "from sklearn.metrics import r2_score\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter configuration for the AttentiveFP lipophilicity regression run.\n",
    "random_seed = 888 # 69, 88\n",
    "# timestamp used to tag saved checkpoints; ':' and ' ' replaced for filesystem-safe names\n",
    "start_time = str(time.ctime()).replace(':','-').replace(' ','_')\n",
    "\n",
    "batch_size = 200\n",
    "epochs = 200 # NOTE(review): unused below -- the training-loop cell hardcodes range(800)\n",
    "\n",
    "p_dropout= 0.2\n",
    "fingerprint_dim = 200 # hidden / fingerprint width passed to the Fingerprint model below\n",
    "\n",
    "weight_decay = 5 # also known as l2_regularization_lambda\n",
    "# weight_decay and learning_rate are log10 exponents: the optimizer cell uses 10**-x\n",
    "learning_rate = 2.5\n",
    "output_units_num = 1 # for regression model\n",
    "# radius and T are AttentiveFP hyperparameters forwarded to Fingerprint (see model cell)\n",
    "radius = 2\n",
    "T = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of all smiles:  4200\n",
      "number of successfully processed smiles:  4200\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU8AAAC/CAYAAAB+KF5fAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAF4RJREFUeJzt3X9M03f+B/BnCyoqhXLYsG+oA6SWrC7DHCBukhDAIiFuipJsXg5njsXdJt6I1As5opdccpk6YidiQ9Azu+ltt+TkRzwMDuTO0+F0mLmozCsCWcgSEJXCpMoB7fePpZ/5Gb/at9BSeD6Sy8H78/p8+n7n4558fr6rcDqdThARkUeUvu4AEZE/YngSEQlgeBIRCWB4EhEJYHgSEQlgeBIRCWB4EhEJYHgSEQlgeBIRCWB4EhEJYHgSEQlgeBIRCWB4EhEJCPR1B6ZDX98gHI65NTlUeHgwHjx45OtuzDiOc+7w1zEqlQqEhS31eL05EZ4Oh3POhSeAOTmm8XCcc8d8GKMLT9uJiAQwPImIBDA8iYgEMDyJiATMiRtGc92IAxgaHpm0ZtGCQATyTyGR1zA8/cDQ8Ai++rZn0pqkFyIQuIi7k8hbeKxCRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJMCtl6GLi4tRXV094fLLly9Do9EgLy8P165dG7M8OzsbZrNZ1jY4OAiz2Yz6+noMDAxAp9Nh165dyMjI8HAIRETe51Z4vvvuu3jjjTdkbSMjI8jPz0dcXBw0Go3UHh0djYMHD8pqw8LCxmyzoKAAra2tMJlM0Gq1qK6uRkFBASoqKpCamioyFiIir3ErPJ9//nk8//zzsrbPP/8cT548QW5urqw9KCgIq1evnnR7Fy9eRHNzM8rLy2E0GgEAa9euRVdXFw4cOMDwJKJZT/ia55kzZ7B48WJkZ2d7vG5DQwNUKpXsFF2hUCAnJwcdHR24e/euaLeIiLxCKDzv3buHS5cuYcOGDQgODpYt6+zsRFJSEgwGAzIzM2GxWDA8PCyraWtrg06ng1Ip//i4uDgAgNVqFekWEZHXCM2eW1NTg9HR0TGn7AkJCcjOzsaKFStgt9vR2NiIsrIy3L59G8eOHZPqbDYboqOjx2w3NDRUWu6J8PDgqYv8kEajAgA4H9qhCg6atHbJkkXQ/GKJN7o17VzjnOvmwzjnwxhdhMKzqqoKUVFRSEpKkrUXFhbKfk9LS8OyZctQUVGBlpYWJCYmSssUCsWE259s2XgePHg0574vWqNRobf3BwCAfWgEPzx6Mmm93T6E3tFRb3RtWj09zrlsPozTX8eoVCqEDsA8Pm1vaWlBZ2cntmzZ4lb95s2bAQA3btyQ2tRq9bhHl/39/QB+OgIlIpqtPD7yPHPmDAICApCTk+NWvcPhAADZ9U2dTofPP/8cDodD1u661qnX6z3tlt+a6MvdnA/tsA/92D7HDqqJ5gSPwtNut6O+vh4pKSmIiIhwa53a2loAQHx8vNRmNBrxj3/8A01NTVi/fr3UXlNTg5iYGOh0Ok+65dcm+nI3VXCQdKoer9eMWU5EvuVReJ47dw52ux1bt24ds6ylpQWVlZXIzMxEZGQk7HY7Lly4gKqqKmRlZSEhIUGqTU1NRXJyMkpKSmCz2aDValFTU4Pr16/DYrE8+6iIiGaYR+FZVVWFsLAwpKenj1nmesuorKwMfX19UCqViImJQXFxMfLy8mS1CoUCFosFhw8fhtlsll7PLC8vH3fbRESzjcLpdPr9FTV/vts+OOTeafs31t5Jt5P0QgSW+uH3tvvrHVpPzYdx+usYvXa3nYiIGJ5EREIYnkREAhieREQCGJ5ERAIYnkREAhieREQCGJ5ERAIYnkREAhieREQC/O99PhqXQqnA4NDYqe2etmhBIAL555JoWjA854ih4VG3
3n8P9MP334lmIx6HEBEJYHgSEQlgeBIRCWB4EhEJYHgSEQlgeBIRCeBzKzNooq8VfpqffnsI0bzH8JxBE32t8NP4tcJE/omn7UREAhieREQCGJ5ERAIYnkREAhieREQCGJ5ERAL4qNI8wjk/iabPlOF59epVbN++fdxl586dQ2xsrPT7F198gSNHjuDOnTtYunQpjEYjTCYTQkJCZOsNDg7CbDajvr4eAwMD0Ol02LVrFzIyMp5xODQZzvlJNH3c/q/EZDIhKSlJ1qbVaqWfr169ip07dyIjIwOFhYW4d+8eSktLYbVa8cknn0Cp/OlwpqCgAK2trTCZTNBqtaiurkZBQQEqKiqQmpo6DcMiIppZbodnTEwMVq9ePeHyDz74ACtXrsSHH34oBaVGo8FvfvMb1NfXIzs7GwBw8eJFNDc3o7y8HEajEQCwdu1adHV14cCBAwxPIvIL03J1q6enBzdv3sSmTZtkR5jr1q1DREQEzp8/L7U1NDRApVLJTtEVCgVycnLQ0dGBu3fvTkeXiIhmlNvhuX//fhgMBiQkJODtt9/GrVu3pGVWqxUAsHLlyjHr6fV6tLW1Sb+3tbVBp9PJQhYA4uLiZNsiIprNpjxtV6lUePPNN7FmzRqo1Wq0t7ejsrIS27Ztw+nTpxEfHw+bzQYACA0NHbN+aGgoWltbpd9tNhuio6PHrXMtJyKa7aYMT4PBAIPBIP2emJiI9PR0bNy4EWazGR999JG0TKFQjLuNn7dPVDfVsomEhwd7vI43OB/aoQoOmrRmwYLACWtc7ZPVuLMdT2qWLFkEzS+WTFoz3TQalVc/z1fmwzjnwxhdhJ5J0Wg0SElJQVNTEwBArVYDGP+osb+/X3ZEqlarJ6wDxj96ncqDB4/gmIUTY9qHRvDDoyeT1gwPj1+jCg6S2ieqcWc7ntbY7UPoHR2dtGY6aTQq9Pb+4LXP85X5ME5/HaNSqRA6ABO+YeRwOKSfXdc6n7626WK1WmXXQnU6Hdrb22Xru+qAH6+REhHNdkLh2dvbi+bmZunRpeeeew4vvvgizp49KwvFK1euoKenB5mZmVKb0WjEwMCAdNTqUlNTg5iYGOh0OpEuERF51ZSn7UVFRVi+fDlWrVqFkJAQdHR04Pjx43jy5An27Nkj1ZlMJuTn52PPnj14/fXX0dPTg9LSUsTHxyMrK0uqS01NRXJyMkpKSmCz2aDValFTU4Pr16/DYrHMzCiJiKbZlOEZFxeHuro6nD59Go8fP4ZarcaaNWvwzjvvyE6xX375ZVRUVODo0aPYuXMnli5divXr12Pv3r0ICAiQ6hQKBSwWCw4fPgyz2Sy9nlleXo709PSZGSUR0TRTOJ3O2XenxUOz9YbR4JB732E03vvmT98wmqjGne14WpP0QgSWevHddn+9yeCp+TBOfx2j128YERHNZwxPIiIBDE8iIgEMTyIiAQxPIiIBDE8iIgEMTyIiAfyyGkEjDmBoePIvU5uFj54S0TRheAoaGnbvAXgimpt42k5EJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJGDK7zC6cuUKamtr8fXXX6O7uxuhoaF46aWXsHv3bsTFxUl1eXl5uHbt2pj1s7OzYTabZW2Dg4Mwm82or6/HwMAAdDoddu3ahYyMjGkYEhHRzJsyPD/99FPYbDbs2LEDsbGxuH//Pk6cOIHc3FycOnUKq1evlmqjo6Nx8OBB2fphYWFjtllQUIDW1laYTCZotVpUV1ejoKAAFRUVSE1NnYZhERHNrCnD849//CPCw8NlbSkpKcjIyMBf/vIXHD16VGoPCgqShel4Ll68iObmZpSXl8NoNAIA1q5di66uLhw4cIDhSUR+Ycprnj8PTgAICQlBVFQUuru7Pf7AhoYGqFQq2Sm6QqFATk4OOjo6
cPfuXY+3SUTkbUI3jB4+fIi2tjasXLlS1t7Z2YmkpCQYDAZkZmbCYrFgeHhYVtPW1gadTgelUv7RruunVqtVpEtERF415Wn7zzmdTuzbtw8OhwP5+flSe0JCArKzs7FixQrY7XY0NjairKwMt2/fxrFjx6Q6m82G6OjoMdsNDQ2VlnsqPDzY43WelfOhHargoElrFiwIfKYaV/uzbseTmiVLFkHziyWT1kw3jUbl1c/zlfkwzvkwRhePw/PQoUNobGzE+++/j9jYWKm9sLBQVpeWloZly5ahoqICLS0tSExMlJYpFIoJtz/Zsok8ePAIDofT4/WehX1oBD88ejJpzfCweI0qOEhqf5bteFpjtw+hd3R00prppNGo0Nv7g9c+z1fmwzj9dYxKpULoAMyj03az2YyTJ0+ipKQEW7ZsmbJ+8+bNAIAbN25IbWq1etyjy/7+fgA/HYESEc1mbofnkSNHUFFRgb1792L79u1ureNwOH78kKeub+p0OrS3t0vLXFzXOvV6vbtdIiLyGbfCs7y8HBaLBe+99x7eeusttzdeW1sLAIiPj5fajEYjBgYG0NTUJKutqalBTEwMdDqd29snIvKVKa95njx5EkePHkVaWhpeeeUV2Sn4woULYTAY0NLSgsrKSmRmZiIyMhJ2ux0XLlxAVVUVsrKykJCQIK2TmpqK5ORklJSUwGazQavVoqamBtevX4fFYpmZURIRTbMpw/Nf//qX9P+un10iIyPR1NQEjUYDACgrK0NfXx+USiViYmJQXFyMvLw82ToKhQIWiwWHDx+G2WyWXs8sLy9Henr6dI2LiGhGTRmep06dmnIjUVFRqKysdPtDg4ODsX//fuzfv9/tdYiIZhPOqkREJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCTA44lBaG5TKBUYHBqZtGbRgkAE8s8uzXMMT5IZGh7FN9beSWuSXohA4CL+06H5jccPREQCGJ5ERAIYnkREAnjhagIjDmBoeOIbJ16euJ6IZhmG5wSGhkfw1bc9Ey6P12u82Bsimm142k5EJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJIDhSUQkgOFJRCSA4UlEJICvZ5LHOGEyEcOTBHDCZCKethMRCWF4EhEJ8Nl51eDgIMxmM+rr6zEwMACdToddu3YhIyNjxj97qrk6Ac7XSUST81l4FhQUoLW1FSaTCVqtFtXV1SgoKEBFRQVSU1Nn9LOnmqsT4HydRDQ5n4TnxYsX0dzcjPLychiNRgDA2rVr0dXVhQMHDsx4eBIRPSufXPNsaGiASqWSnaIrFArk5OSgo6MDd+/e9UW3aBq5Hmea7H8jDl/3kkicT44829raoNPpoFTKszsuLg4AYLVaodPp3N6eUqnw6PMDA5RYErTgmWqmYxuT1SxeFIjRkQVe+ayZqBl1OPFt58NJa+L1Gty3PcbQBCm6MDAAAXPolqan/079kT+OUbTPPglPm82G6OjoMe2hoaHSck+EhS31uA/a/wudsmaFNuyZlrNm6hoAgHqxe3V+Ljw82NddmHHzYYwuPvu7rlBMnPaTLSMimg18Ep5qtXrco8v+/n4APx2BEhHNVj4JT51Oh/b2djgc8mtdVqsVAKDX633RLSIit/kkPI1GIwYGBtDU1CRrr6mpQUxMjEc3i4iIfMEnN4xSU1ORnJyMkpIS2Gw2aLVa1NTU4Pr167BYLL7oEhGRRxROp9MnLyI+evQIhw8fxvnz52WvZ65fv94X3SEi8ojPwpOIyJ/NoUeQiYi8h+FJRCSAU337wJUrV1BbW4uvv/4a3d3dCA0NxUsvvYTdu3dLr6gCQF5eHq5duzZm/ezsbJjNZm92WcjVq1exffv2cZedO3cOsbGx0u9ffPEFjhw5gjt37mDp0qUwGo0wmUwICQnxVneFFBcXo7q6esLlly9fhkaj8at92d3djRMnTuD27du4c+cO7HY7Pv74YyQnJ4+pPXv2LI4fP47O
zk6EhYXhtddew+7du7Fo0SJZ3f379/HBBx/g3//+N4aGhmAwGGAymfDLX/7SW8OadgxPH/j0009hs9mwY8cOxMbG4v79+zhx4gRyc3Nx6tQprF69WqqNjo7GwYMHZeuHhbn52uMsYTKZkJSUJGvTarXSz1evXsXOnTuRkZGBwsJC3Lt3D6WlpbBarfjkk0/GzIEwm7z77rt44403ZG0jIyPIz89HXFwcNJqfpjb0l3353Xffoa6uDgaDAWvXrh3zSKFLbW0tfv/732Pbtm34wx/+gPb2dpSWluL777+X/UEYGhrCjh07YLfbsW/fPqjVavz1r3/Fjh078Pe//x0Gg8FbQ5teTvK6+/fvj2nr7+93JiYmOgsKCqS2X//6187XXnvNm12bVl9++aVTr9c7GxoaJq3bunWrc9OmTc7R0VGp7fLly069Xu+sq6ub6W5Ou/Pnzzv1er3zs88+k9r8aV8+vR8aGhqcer3e+eWXX8pqRkZGnOvWrXP+9re/lbV/9tlnTr1e77xx44bUdvr0aader3feunVLahsaGnKmp6c78/PzZ2gUM2/2/kmfw8LDw8e0hYSEICoqCt3d3T7oke/09PTg5s2b2LRpk+wIc926dYiIiMD58+d92DsxZ86cweLFi5Gdne3rrghx50j/xo0b6O3tRU5Ojqz91VdfxYIFC2T7rbGxEXq9HqtWrZLaFi5ciI0bN6K5uRmPHj2avs57EcNzlnj48CHa2tqwcuVKWXtnZyeSkpJgMBiQmZkJi8WC4eFhH/VSzP79+2EwGJCQkIC3334bt27dkpa5Xsn9+biBH1/TbWtr81o/p8O9e/dw6dIlbNiwAcHB8hmG5sK+dHHtl5/vt8WLF2P58uWy/dbW1jbuK9dxcXEYHR1FR0fHzHZ2hvCa5yzgdDqxb98+OBwO5OfnS+0JCQnIzs7GihUrYLfb0djYiLKyMty+fRvHjh3zYY/do1Kp8Oabb2LNmjVQq9Vob29HZWUltm3bhtOnTyM+Pl6aIGa8yWBCQ0PR2trq7W4/k5qaGoyOjiI3N1fW7u/78uem2m9PT/xjs9kmrAOAvr6+GerlzGJ4zgKHDh1CY2Mj3n//fdkd6MLCQlldWloali1bhoqKCrS0tCAxMdHbXfWIwWCQ3QxITExEeno6Nm7cCLPZjI8++khaNtE0hP42PWFVVRWioqLG3CDz9305EXf321ycgpKn7T5mNptx8uRJlJSUYMuWLVPWb968GcCP15z8kUajQUpKCr755hsAP05PCIw/AXZ/f79fTU/Y0tKCzs5Ot/Yj4N/70pP9NtUUlK5t+RuGpw8dOXIEFRUV2Lt374TPQ/6caxq/2fz4zlSenorQdc1svGubVqt13Guhs9WZM2cQEBAw5ibKRPx5X7pmPvv5fnv8+DG6urpk+02n00nXtp/23//+FwEBAVixYsXMdnaG+N9emyPKy8thsVjw3nvv4a233nJ7vdraWgBAfHz8THVtRvX29qK5uVl6lvW5557Diy++iLNnz8pC9cqVK+jp6UFmZqavuuoRu92O+vp6pKSkICIiwq11/Hlfrl69GhqNRhqDyz//+U8MDw/L9pvRaITVasW3334rtf3vf/9DXV0dXn755TE31vwFr3n6wMmTJ3H06FGkpaXhlVdekZ22LVy4EAaDAS0tLaisrERmZiYiIyNht9tx4cIFVFVVISsrCwkJCT4cgXuKioqwfPlyrFq1CiEhIejo6MDx48fx5MkT7NmzR6ozmUzIz8/Hnj178Prrr6OnpwelpaWIj49HVlaWD0fgvnPnzsFut2Pr1q1jlvnjvqyvrwcA3Lx5EwDw1Vdfoa+vD4sXL0ZqaioCAwNRVFSE4uJi/OlPf8KGDRukh+Q3bNgge9EjNzcXf/vb31BQUICioiKEhobi448/xr179/Dhhx/6ZHzTgbMq+cBEr+oBQGRkJJqamvDdd9/hz3/+M+7cuYO+vj4olUrExMRg8+bNyMvLQ0BAgJd7
7bnKykrU1dXh+++/x+PHj6FWq7FmzRq88847Yx5d+c9//oOjR49Kr2euX78ee/fu9Ztrnr/61a/Q0dGBS5cuYcEC+TeL+uO+fPo14ae5/n261NbW4sSJE9Lrma+++ip+97vfISgoSLZeb28vDh06hIsXL0qvZxYVFfntjTKA4UlEJITXPImIBDA8iYgEMDyJiAQwPImIBDA8iYgEMDyJiAQwPImIBDA8iYgEMDyJiAT8P8GBycF4SVz0AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the raw CSV, canonicalize every SMILES with RDKit, and plot the atom-count distribution.\n",
    "task_name = 'lipophilicity'\n",
    "tasks = ['exp']\n",
    "\n",
    "raw_filename = \"../data/Lipophilicity.csv\"\n",
    "feature_filename = raw_filename.replace('.csv','.pickle')\n",
    "filename = raw_filename.replace('.csv','')\n",
    "prefix_filename = raw_filename.split('/')[-1].replace('.csv','')\n",
    "smiles_tasks_df = pd.read_csv(raw_filename)\n",
    "smilesList = smiles_tasks_df.smiles.values\n",
    "print(\"number of all smiles: \",len(smilesList))\n",
    "atom_num_dist = []\n",
    "remained_smiles = []\n",
    "canonical_smiles_list = []\n",
    "# Best-effort canonicalization: unparseable SMILES are reported and dropped.\n",
    "for smiles in smilesList:\n",
    "    try:\n",
    "        mol = Chem.MolFromSmiles(smiles)\n",
    "        atom_num_dist.append(len(mol.GetAtoms()))\n",
    "        remained_smiles.append(smiles)\n",
    "        canonical_smiles_list.append(Chem.MolToSmiles(mol, isomericSmiles=True))\n",
    "    except Exception:  # MolFromSmiles returns None for bad SMILES; .GetAtoms() then raises\n",
    "        print(smiles)\n",
    "print(\"number of successfully processed smiles: \", len(remained_smiles))\n",
    "smiles_tasks_df = smiles_tasks_df[smiles_tasks_df[\"smiles\"].isin(remained_smiles)]\n",
    "# canonical_smiles_list was appended in lockstep with remained_smiles, so it lines up\n",
    "# 1:1 with the rows kept by the isin() filter above\n",
    "smiles_tasks_df['cano_smiles'] = canonical_smiles_list\n",
    "# spot-check: stored canonical SMILES are a fixed point of canonicalization\n",
    "assert canonical_smiles_list[8]==Chem.MolToSmiles(Chem.MolFromSmiles(smiles_tasks_df['cano_smiles'][8]), isomericSmiles=True)\n",
    "\n",
    "# distribution of molecule sizes (atom counts)\n",
    "plt.figure(figsize=(5, 3))\n",
    "sns.set(font_scale=1.5)\n",
    "ax = sns.distplot(atom_num_dist, bins=28, kde=False)\n",
    "plt.tight_layout()\n",
    "# plt.savefig(\"atom_num_dist_\"+prefix_filename+\".png\",dpi=200)\n",
    "plt.show()\n",
    "plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "not processed items\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>CMPD_CHEMBLID</th>\n",
       "      <th>exp</th>\n",
       "      <th>smiles</th>\n",
       "      <th>cano_smiles</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "Empty DataFrame\n",
       "Columns: [CMPD_CHEMBLID, exp, smiles, cano_smiles]\n",
       "Index: []"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load cached atom/bond feature dictionaries if present; otherwise build and cache them.\n",
    "# NOTE(review): pickle.load can execute arbitrary code from an untrusted file -- only load\n",
    "# caches this pipeline wrote itself.\n",
    "if os.path.isfile(feature_filename):\n",
    "    feature_dicts = pickle.load(open(feature_filename, \"rb\" ))\n",
    "else:\n",
    "    feature_dicts = save_smiles_dicts(smilesList,filename)\n",
    "# feature_dicts = get_smiles_dicts(smilesList)\n",
    "# Keep only molecules that were successfully featurized; display the ones that were not.\n",
    "remained_df = smiles_tasks_df[smiles_tasks_df[\"cano_smiles\"].isin(feature_dicts['smiles_to_atom_mask'].keys())]\n",
    "uncovered_df = smiles_tasks_df.drop(remained_df.index)\n",
    "print(\"not processed items\")\n",
    "uncovered_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 80/10/10 split: 10% held out as test, then 1/9 of the remaining 90% (10% overall)\n",
    "# as validation. Both draws use random_seed for reproducibility.\n",
    "remained_df = remained_df.reset_index(drop=True)\n",
    "test_df = remained_df.sample(frac=1/10, random_state=random_seed) # test set\n",
    "training_data = remained_df.drop(test_df.index) # training data\n",
    "\n",
    "# training data is further divided into validation set and train set\n",
    "valid_df = training_data.sample(frac=1/9, random_state=random_seed) # validation set\n",
    "train_df = training_data.drop(valid_df.index) # train set\n",
    "# reset to 0..n-1 indices so positional batching (np.arange + .loc) in train/eval works\n",
    "train_df = train_df.reset_index(drop=True)\n",
    "valid_df = valid_df.reset_index(drop=True)\n",
    "test_df = test_df.reset_index(drop=True)\n",
    "\n",
    "# print(len(test_df),sorted(test_df.cano_smiles.values))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "863604\n",
      "atom_fc.weight torch.Size([200, 39])\n",
      "atom_fc.bias torch.Size([200])\n",
      "neighbor_fc.weight torch.Size([200, 49])\n",
      "neighbor_fc.bias torch.Size([200])\n",
      "GRUCell.0.weight_ih torch.Size([600, 200])\n",
      "GRUCell.0.weight_hh torch.Size([600, 200])\n",
      "GRUCell.0.bias_ih torch.Size([600])\n",
      "GRUCell.0.bias_hh torch.Size([600])\n",
      "GRUCell.1.weight_ih torch.Size([600, 200])\n",
      "GRUCell.1.weight_hh torch.Size([600, 200])\n",
      "GRUCell.1.bias_ih torch.Size([600])\n",
      "GRUCell.1.bias_hh torch.Size([600])\n",
      "align.0.weight torch.Size([1, 400])\n",
      "align.0.bias torch.Size([1])\n",
      "align.1.weight torch.Size([1, 400])\n",
      "align.1.bias torch.Size([1])\n",
      "attend.0.weight torch.Size([200, 200])\n",
      "attend.0.bias torch.Size([200])\n",
      "attend.1.weight torch.Size([200, 200])\n",
      "attend.1.bias torch.Size([200])\n",
      "mol_GRUCell.weight_ih torch.Size([600, 200])\n",
      "mol_GRUCell.weight_hh torch.Size([600, 200])\n",
      "mol_GRUCell.bias_ih torch.Size([600])\n",
      "mol_GRUCell.bias_hh torch.Size([600])\n",
      "mol_align.weight torch.Size([1, 400])\n",
      "mol_align.bias torch.Size([1])\n",
      "mol_attend.weight torch.Size([200, 200])\n",
      "mol_attend.bias torch.Size([200])\n",
      "output.weight torch.Size([1, 200])\n",
      "output.bias torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "# Build the AttentiveFP model; atom/bond feature sizes are inferred from one featurized molecule.\n",
    "x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array([canonical_smiles_list[0]],feature_dicts)\n",
    "num_atom_features = x_atom.shape[-1]\n",
    "num_bond_features = x_bonds.shape[-1]\n",
    "loss_function = nn.MSELoss()\n",
    "model = Fingerprint(radius, T, num_atom_features, num_bond_features,\n",
    "            fingerprint_dim, output_units_num, p_dropout)\n",
    "model.cuda()\n",
    "\n",
    "# optimizer = optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)\n",
    "# learning_rate / weight_decay hold log10 exponents, so actual lr = 10**-2.5, wd = 10**-5\n",
    "optimizer = optim.Adam(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)\n",
    "# optimizer = optim.SGD(model.parameters(), 10**-learning_rate, weight_decay=10**-weight_decay)\n",
    "\n",
    "# tensorboard = SummaryWriter(log_dir=\"runs/\"+start_time+\"_\"+prefix_filename+\"_\"+str(fingerprint_dim)+\"_\"+str(p_dropout))\n",
    "\n",
    "# Count trainable parameters, then list each one with its shape.\n",
    "model_parameters = filter(lambda p: p.requires_grad, model.parameters())\n",
    "params = sum([np.prod(p.size()) for p in model_parameters])\n",
    "print(params)\n",
    "for name, param in model.named_parameters():\n",
    "    if param.requires_grad:\n",
    "        print(name, param.data.shape)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, dataset, optimizer, loss_function):\n",
    "    \"\"\"Run one epoch of minibatch training over `dataset`.\n",
    "\n",
    "    `dataset` must have a 0..n-1 integer index (see the split cell). Relies on\n",
    "    the notebook globals `epoch` (shuffle seed), `batch_size`, `tasks` and\n",
    "    `feature_dicts` defined in earlier cells.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    np.random.seed(epoch)  # deterministic shuffle per epoch\n",
    "    valList = np.arange(0,dataset.shape[0])\n",
    "    #shuffle them\n",
    "    np.random.shuffle(valList)\n",
    "    batch_list = []\n",
    "    for i in range(0, dataset.shape[0], batch_size):\n",
    "        batch = valList[i:i+batch_size]\n",
    "        batch_list.append(batch)\n",
    "    for counter, train_batch in enumerate(batch_list):\n",
    "        batch_df = dataset.loc[train_batch,:]\n",
    "        smiles_list = batch_df.cano_smiles.values\n",
    "        y_val = batch_df[tasks[0]].values\n",
    "\n",
    "        x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)\n",
    "        atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask))\n",
    "\n",
    "        model.zero_grad()\n",
    "        loss = loss_function(mol_prediction, torch.Tensor(y_val).view(-1,1))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "def eval(model, dataset):\n",
    "    \"\"\"Return (mean MAE, mean MSE) of `model` over `dataset`.\n",
    "\n",
    "    NOTE: intentionally shadows the builtin `eval`; later cells call it by this name.\n",
    "    Uses the notebook globals `batch_size`, `tasks` and `feature_dicts`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    eval_MAE_list = []\n",
    "    eval_MSE_list = []\n",
    "    valList = np.arange(0,dataset.shape[0])\n",
    "    batch_list = []\n",
    "    for i in range(0, dataset.shape[0], batch_size):\n",
    "        batch = valList[i:i+batch_size]\n",
    "        batch_list.append(batch)\n",
    "    with torch.no_grad():  # inference only -- skip building autograd graphs\n",
    "        for counter, eval_batch in enumerate(batch_list):\n",
    "            batch_df = dataset.loc[eval_batch,:]\n",
    "            smiles_list = batch_df.cano_smiles.values\n",
    "            y_val = batch_df[tasks[0]].values\n",
    "\n",
    "            x_atom, x_bonds, x_atom_index, x_bond_index, x_mask, smiles_to_rdkit_list = get_smiles_array(smiles_list,feature_dicts)\n",
    "            atoms_prediction, mol_prediction = model(torch.Tensor(x_atom),torch.Tensor(x_bonds),torch.cuda.LongTensor(x_atom_index),torch.cuda.LongTensor(x_bond_index),torch.Tensor(x_mask))\n",
    "            MAE = F.l1_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')\n",
    "            MSE = F.mse_loss(mol_prediction, torch.Tensor(y_val).view(-1,1), reduction='none')\n",
    "            # .view(-1) (not .squeeze()): a final batch of size 1 would squeeze to a 0-d\n",
    "            # array, and list.extend() on a 0-d array raises\n",
    "            eval_MAE_list.extend(MAE.data.view(-1).cpu().numpy())\n",
    "            eval_MSE_list.extend(MSE.data.view(-1).cpu().numpy())\n",
    "    return np.array(eval_MAE_list).mean(), np.array(eval_MSE_list).mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 2.4049506 2.3832927\n",
      "1 1.1859775 1.2447138\n",
      "2 1.1304924 1.1572254\n",
      "3 1.0605187 1.0694801\n",
      "4 1.1581542 1.1446294\n",
      "5 1.0070105 1.0034809\n",
      "6 0.98379123 0.9846309\n",
      "7 0.8907563 0.8544044\n",
      "8 0.8471567 0.81949204\n",
      "9 0.843251 0.80876637\n",
      "10 0.77015465 0.7529482\n",
      "11 0.8456124 0.83786714\n",
      "12 0.76429266 0.7905526\n",
      "13 0.74701184 0.7459793\n",
      "14 0.8022126 0.79245126\n",
      "15 0.71984875 0.7456648\n",
      "16 0.7100498 0.73237807\n",
      "17 0.6823546 0.72056365\n",
      "18 0.6711586 0.7077812\n",
      "19 0.6523204 0.6957419\n",
      "20 0.6416393 0.6903162\n",
      "21 0.6380069 0.6804075\n",
      "22 0.6401724 0.7105676\n",
      "23 0.62895536 0.691\n",
      "24 0.6036802 0.67823094\n",
      "25 0.5912554 0.6645063\n",
      "26 0.5838335 0.6800206\n",
      "27 0.5859686 0.66535527\n",
      "28 0.5800196 0.67452335\n",
      "29 0.5645402 0.6593988\n",
      "30 0.68495387 0.75959456\n",
      "31 0.58664095 0.6653512\n",
      "32 0.53799355 0.6485197\n",
      "33 0.5151322 0.62061983\n",
      "34 0.5146866 0.6233502\n",
      "35 0.5173476 0.6324485\n",
      "36 0.50005925 0.62617224\n",
      "37 0.51448786 0.65714794\n",
      "38 0.49094093 0.64261687\n",
      "39 0.49265963 0.6364139\n",
      "40 0.5068545 0.64852035\n",
      "41 0.5066594 0.65385514\n",
      "42 0.47307152 0.6281636\n",
      "43 0.48441094 0.6241973\n",
      "44 0.46148673 0.6157281\n",
      "45 0.4566079 0.6286194\n",
      "46 0.48820984 0.6658668\n",
      "47 0.44791496 0.62499976\n",
      "48 0.4180905 0.6110794\n",
      "49 0.41011438 0.61677766\n",
      "50 0.40652218 0.60951215\n",
      "51 0.3932477 0.61963904\n",
      "52 0.41370538 0.62278\n",
      "53 0.38900146 0.6307605\n",
      "54 0.38249263 0.63137275\n",
      "55 0.39164448 0.64020467\n",
      "56 0.41899917 0.6510923\n",
      "57 0.41099358 0.6278744\n",
      "58 0.3632155 0.6134981\n",
      "59 0.38056228 0.6329805\n",
      "60 0.36290747 0.62854725\n",
      "61 0.34744015 0.61421853\n",
      "62 0.35305616 0.6218957\n",
      "63 0.31901267 0.6099687\n",
      "64 0.33127937 0.6377519\n",
      "65 0.3408921 0.6134504\n",
      "66 0.3206371 0.6079148\n",
      "67 0.33201426 0.64127654\n",
      "68 0.36489865 0.63531923\n",
      "69 0.3153412 0.6208934\n",
      "70 0.31272176 0.61238915\n",
      "71 0.3175425 0.61595315\n",
      "72 0.3162924 0.6223968\n",
      "73 0.34022078 0.65672195\n",
      "74 0.28285202 0.60976285\n",
      "75 0.28596932 0.6262418\n",
      "76 0.28054526 0.59693474\n",
      "77 0.27484897 0.5994352\n",
      "78 0.28349665 0.6089215\n",
      "79 0.26419622 0.6098417\n",
      "80 0.2603183 0.60027146\n",
      "81 0.2551902 0.5946019\n",
      "82 0.24822421 0.5898403\n",
      "83 0.2439964 0.6158857\n",
      "84 0.25535578 0.61821455\n",
      "85 0.25936157 0.61194247\n",
      "86 0.26839408 0.6271016\n",
      "87 0.23975955 0.6051806\n",
      "88 0.27670372 0.6303385\n",
      "89 0.29347762 0.64066476\n",
      "90 0.24213839 0.6245132\n",
      "91 0.32193494 0.6408691\n",
      "92 0.2370674 0.61073714\n",
      "93 0.2533699 0.63283044\n",
      "94 0.24604379 0.61285895\n",
      "95 0.21734874 0.59754926\n",
      "96 0.21954466 0.60114396\n",
      "97 0.21892889 0.5977851\n",
      "98 0.23404442 0.6116714\n",
      "99 0.22917213 0.6099388\n",
      "100 0.2179149 0.6083288\n",
      "101 0.20048712 0.6010136\n",
      "102 0.23239891 0.6108931\n",
      "103 0.20812373 0.5810207\n",
      "104 0.21366884 0.60311955\n",
      "105 0.19647895 0.6052785\n",
      "106 0.19350101 0.59440595\n",
      "107 0.20290661 0.59335977\n",
      "108 0.19378991 0.6046318\n",
      "109 0.18944794 0.599456\n",
      "110 0.17604701 0.6027224\n",
      "111 0.20567027 0.5970574\n",
      "112 0.1932996 0.60084045\n",
      "113 0.1784604 0.6058249\n",
      "114 0.1806693 0.5847693\n",
      "115 0.18622671 0.60493535\n",
      "116 0.19030958 0.6229752\n",
      "117 0.19659032 0.5908643\n",
      "118 0.22107625 0.62124723\n",
      "119 0.19126216 0.6213941\n",
      "120 0.200828 0.60549873\n",
      "121 0.18057448 0.60095245\n",
      "122 0.16429934 0.5906629\n",
      "123 0.16509712 0.6064466\n",
      "124 0.18519506 0.61307883\n",
      "125 0.23898442 0.6078076\n",
      "126 0.18202212 0.60592526\n",
      "127 0.18048878 0.6040053\n",
      "128 0.19524215 0.59627646\n",
      "129 0.16004398 0.5855707\n",
      "130 0.15652551 0.59097767\n",
      "131 0.16138454 0.5817338\n",
      "132 0.16862331 0.5908905\n",
      "133 0.18007146 0.6011837\n",
      "134 0.1569539 0.61926\n",
      "135 0.15992038 0.60304826\n",
      "136 0.1712815 0.6131753\n",
      "137 0.15518016 0.5966308\n",
      "138 0.1597793 0.5764525\n",
      "139 0.15061416 0.59029555\n",
      "140 0.17691936 0.6219068\n",
      "141 0.17803456 0.59521747\n",
      "142 0.15063532 0.58722574\n",
      "143 0.1523661 0.60461086\n",
      "144 0.1464871 0.5960423\n",
      "145 0.1579151 0.5936966\n",
      "146 0.15030569 0.5929402\n",
      "147 0.14424418 0.5856528\n",
      "148 0.1567422 0.59671926\n",
      "149 0.20696597 0.6097623\n",
      "150 0.15571243 0.5858791\n",
      "151 0.15008938 0.58341545\n",
      "152 0.13767089 0.5924539\n",
      "153 0.14620863 0.59901404\n",
      "154 0.16745202 0.5861075\n",
      "155 0.14202572 0.5878751\n",
      "156 0.14645408 0.5896061\n",
      "157 0.13574576 0.5843238\n",
      "158 0.14178416 0.58851945\n",
      "159 0.14861764 0.5943505\n",
      "160 0.14865972 0.5875144\n",
      "161 0.16231185 0.6084909\n",
      "162 0.15479845 0.6041366\n",
      "163 0.21155412 0.60427994\n",
      "164 0.1481853 0.5889484\n",
      "165 0.1618286 0.6026567\n"
     ]
    }
   ],
   "source": [
    "# Training loop with early stopping; checkpoints the model whenever validation MSE improves\n",
    "# past the 0.40 threshold.\n",
    "best_param ={}\n",
    "best_param[\"train_epoch\"] = 0\n",
    "best_param[\"valid_epoch\"] = 0\n",
    "best_param[\"train_MSE\"] = 9e8\n",
    "best_param[\"valid_MSE\"] = 9e8\n",
    "\n",
    "# ensure the checkpoint directory exists before the first torch.save\n",
    "os.makedirs('saved_models', exist_ok=True)\n",
    "\n",
    "for epoch in range(800):\n",
    "    train_MAE, train_MSE = eval(model, train_df)\n",
    "    valid_MAE, valid_MSE = eval(model, valid_df)\n",
    "#     tensorboard.add_scalars('MAE',{'train_MAE':valid_MAE, 'test_MAE':valid_MSE}, epoch)\n",
    "#     tensorboard.add_scalars('MSE',{'train_MSE':valid_MAE, 'test_MSE':valid_MSE}, epoch)\n",
    "    if train_MSE < best_param[\"train_MSE\"]:\n",
    "        best_param[\"train_epoch\"] = epoch\n",
    "        best_param[\"train_MSE\"] = train_MSE\n",
    "    if valid_MSE < best_param[\"valid_MSE\"]:\n",
    "        best_param[\"valid_epoch\"] = epoch\n",
    "        best_param[\"valid_MSE\"] = valid_MSE\n",
    "        if valid_MSE < 0.40:  # only checkpoint models that are already reasonably good\n",
    "             torch.save(model, 'saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(epoch)+'.pt')\n",
    "    # early stopping: train MSE stale for >8 epochs AND valid MSE stale for >18 epochs\n",
    "    if (epoch - best_param[\"train_epoch\"] >8) and (epoch - best_param[\"valid_epoch\"] >18):\n",
    "        break\n",
    "    print(epoch, np.sqrt(train_MSE), np.sqrt(valid_MSE))  # report RMSE, not MSE\n",
    "\n",
    "    train(model, train_df, optimizer, loss_function)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "best epoch: 138 \n",
      " test RMSE: 0.56426316\n"
     ]
    }
   ],
   "source": [
    "# evaluate model: reload the checkpoint from the best validation epoch and score the test set\n",
    "best_model = torch.load('saved_models/model_'+prefix_filename+'_'+start_time+'_'+str(best_param[\"valid_epoch\"])+'.pt')\n",
    "\n",
    "best_model_dict = best_model.state_dict()\n",
    "best_model_wts = copy.deepcopy(best_model_dict)\n",
    "\n",
    "model.load_state_dict(best_model_wts)\n",
    "# sanity check that the copied weights actually landed in `model`\n",
    "# (the original expression statement silently discarded this comparison)\n",
    "assert (best_model.align[0].weight == model.align[0].weight).all()\n",
    "test_MAE, test_MSE = eval(model, test_df)\n",
    "print(\"best epoch:\",best_param[\"valid_epoch\"],\"\\n\",\"test RMSE:\",np.sqrt(test_MSE))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
